Source code for nlp_architect.common.cdc.cluster

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
from typing import List, Optional

from nlp_architect.common.cdc.mention_data import MentionData


class Cluster(object):
    """A set of mentions that share the same coref chain id."""

    def __init__(self, coref_chain: int = -1) -> None:
        """
        Object representing a set of mentions with the same coref chain id

        Args:
            coref_chain (int): the cluster id/coref_chain value
        """
        self.mentions = []
        self.cluster_strings = []
        # Flag for callers: presumably set once this cluster has been merged
        # into another one (callers filter on it); not set by this class.
        self.merged = False
        self.coref_chain = coref_chain
        # The distinct ``coref_chain`` values of the mentions held here.
        self.mentions_corefs = set()

    def get_mentions(self):
        """
        Returns:
            list: the mentions held by this cluster (the live list, not a copy)
        """
        return self.mentions

    def add_mention(self, mention: "MentionData") -> None:
        """
        Add a single mention to this cluster.

        The mention's ``predicted_coref_chain`` is set to this cluster's
        ``coref_chain``; its ``tokens_str`` is appended to ``cluster_strings``
        and its own ``coref_chain`` is recorded in ``mentions_corefs``.

        Args:
            mention: the mention to add; ``None`` is silently ignored
        """
        if mention is None:
            return
        mention.predicted_coref_chain = self.coref_chain
        self.mentions.append(mention)
        self.cluster_strings.append(mention.tokens_str)
        self.mentions_corefs.add(mention.coref_chain)

    def merge_clusters(self, cluster: "Cluster") -> None:
        """
        Absorb all mentions of ``cluster`` into this cluster.

        Every absorbed mention gets this cluster's ``coref_chain`` as its
        ``predicted_coref_chain``. The lists of ``cluster`` itself and its
        ``merged`` flag are left untouched. Merging a cluster with itself is
        a no-op.

        Args:
            cluster: cluster to merge this cluster with
        """
        # Extending the lists with themselves would duplicate every mention,
        # string and (via get_cluster_id) every id component.
        if cluster is self:
            return
        for mention in cluster.mentions:
            mention.predicted_coref_chain = self.coref_chain
        self.mentions.extend(cluster.mentions)
        self.cluster_strings.extend(cluster.cluster_strings)
        self.mentions_corefs.update(cluster.mentions_corefs)

    def get_cluster_id(self) -> str:
        """
        Returns:
            A generated cluster unique Id created from cluster mentions ids,
            joined with ``'$'`` in mention order
        """
        return '$'.join(mention.mention_id for mention in self.mentions)
class Clusters(object):
    """A collection of Cluster objects tagged with a topic id."""

    # Class-wide counter handing out coref chain ids to newly created
    # clusters; shared by, and advanced through, every Clusters instance.
    cluster_coref_chain = 1000

    def __init__(self, topic_id: str, mentions: Optional[List["MentionData"]] = None) -> None:
        """
        Args:
            topic_id: identifier stored as ``self.topic_id``
            mentions: ``list[MentionData]``, optional
                The initial mentions to create the clusters from; each mention
                becomes its own single-mention cluster. ``None`` or an empty
                list creates no clusters.
        """
        self.clusters_list = []
        self.topic_id = topic_id
        self.set_initial_clusters(mentions)

    def set_initial_clusters(self, mentions: Optional[List["MentionData"]]) -> None:
        """
        Create one new cluster per mention and append it to ``clusters_list``.

        Args:
            mentions: ``list[MentionData]``, optional
                The initial mentions to create the clusters from; ``None`` or
                empty is a no-op
        """
        if not mentions:
            return
        for mention in mentions:
            cluster = Cluster(Clusters.cluster_coref_chain)
            cluster.add_mention(mention)
            self.clusters_list.append(cluster)
            # Advance the shared counter so each new cluster gets a distinct id.
            Clusters.cluster_coref_chain += 1

    def clean_clusters(self) -> None:
        """
        Remove all clusters that were already merged with other clusters
        (i.e. whose ``merged`` flag is truthy)
        """
        self.clusters_list = [cluster for cluster in self.clusters_list if not cluster.merged]

    def set_coref_chain_to_mentions(self) -> None:
        """
        Give all cluster mentions the same coref ID as cluster coref chain ID,
        stored as a string
        """
        for cluster in self.clusters_list:
            for mention in cluster.mentions:
                mention.predicted_coref_chain = str(cluster.coref_chain)

    def add_cluster(self, cluster: "Cluster") -> None:
        """
        Append a single cluster to this collection.

        Args:
            cluster: the cluster to append
        """
        self.clusters_list.append(cluster)

    def add_clusters(self, clusters: "Clusters") -> None:
        """
        Append every cluster held by ``clusters`` to this collection.

        Args:
            clusters: another Clusters collection; passing ``self`` appends
                the current clusters once more
        """
        # Copy first: the previous append-while-iterating loop never
        # terminated when ``clusters`` was ``self``.
        self.clusters_list.extend(list(clusters.clusters_list))